package org.zjvis.datascience.spark.util;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.io.Serializable;

/**
 * Helper utilities for caching tables through a Spark session's catalog.
 *
 * <p>All operations are delegated to the catalog of the {@link SparkSession}
 * supplied at construction time.
 *
 * @date 2021-12-23
 */
public class CacheUtil implements Serializable {

    /** Session whose catalog is used for every cache operation; never reassigned. */
    private final SparkSession sparkSession;

    /**
     * Creates a helper bound to the given session.
     *
     * @param sparkSession session whose catalog will be queried and modified
     */
    public CacheUtil(SparkSession sparkSession) {
        this.sparkSession = sparkSession;
    }

    /**
     * Derives the name used for the cached copy of a table. The resulting name
     * must not contain ".", so every "." in the input is replaced by "_" before
     * the "_sample_cached" suffix is appended.
     *
     * @param tableName original table name, possibly qualified with "."
     * @return the dot-free name with the "_sample_cached" suffix
     */
    public String modifyCacheTableName(String tableName) {
        // Literal character replacement; no regular expression is needed here.
        String dotFreeName = tableName.replace('.', '_');
        return String.format("%s_sample_cached", dotFreeName);
    }

    /**
     * Tells whether the named table is both known to the catalog and cached.
     *
     * @param tableName name of the table to look up
     * @return {@code true} only if the catalog reports the table as existing
     *         and as cached
     */
    public boolean isCacheTableExists(String tableName) {
        // Existence is checked first so that isCached is only consulted for a
        // table the catalog already knows about.
        return sparkSession.catalog().tableExists(tableName)
                && sparkSession.catalog().isCached(tableName);
    }

    /**
     * Registers the dataset as a temporary view under {@code tableName} and asks
     * the catalog to cache it.
     *
     * @param dataset   data to expose and cache
     * @param tableName name of the temporary view to create (or replace)
     * @return {@code true} if both steps completed, {@code false} if either one
     *         threw an exception
     */
    public boolean cacheTableForDataset(Dataset<Row> dataset, String tableName) {
        try {
            // createOrReplaceTempView is the non-deprecated replacement for
            // registerTempTable.
            dataset.createOrReplaceTempView(tableName);
            sparkSession.catalog().cacheTable(tableName);
        } catch (Exception e) {
            // Deliberately best-effort: failure is reported to the caller only
            // through the return value.
            return false;
        }
        return true;
    }

    /**
     * Removes the named table from the cache.
     *
     * <p>Unlike {@link #cacheTableForDataset(Dataset, String)}, nothing is
     * caught here: any exception raised by the catalog propagates to the caller.
     *
     * @param tableName name of the table to uncache
     * @return always {@code true} when the call returns normally
     */
    public boolean unCacheTable(String tableName) {
        sparkSession.catalog().uncacheTable(tableName);
        return true;
    }
}
